import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
import cv2
from PIL import Image
import yaml
from prettytable import PrettyTable
from tqdm import tqdm

import common
from utils.image import letterbox_image, get_n_hls_colors_v2
from utils.decode import decode_outputs, non_max_suppression, get_real_boxes


class Yolo(object):
    """YOLO detector running on a serialized TensorRT engine.

    Loads the engine, allocates its I/O buffers via ``common.allocate_buffers``,
    reads a YAML config and provides pre-processing, inference,
    post-processing and visualisation helpers.

    Config keys read by this class: ``classes``, ``input_shape``, ``mean``,
    ``std``, ``stage``, ``conf_thres`` and ``iou_thres``.
    """

    def __init__(self, engine_path, config, logger_severity, warmup_epoch=10):
        """
        Build the detector and warm the engine up.

        :param engine_path: path to the serialized TensorRT engine file
        :param config: path to the YAML config file
        :param logger_severity: severity level for the TensorRT logger
        :param warmup_epoch: number of warm-up inference runs
        :raises RuntimeError: if the engine file cannot be deserialized
        """
        self.logger = trt.Logger(logger_severity)  # TensorRT logger
        # Deserialize the engine, create an execution context and allocate buffers.
        self.engine = self.get_engine(engine_path)
        self.context = self.engine.create_execution_context()
        self.inputs, self.outputs, self.binding, self.stream = common.allocate_buffers(self.engine)

        # YAML is UTF-8 by spec; be explicit so non-ASCII class names load the
        # same way regardless of the platform's default locale encoding.
        with open(config, "r", encoding="utf-8") as fp:
            self.config = yaml.safe_load(fp)

        self.num_classes = len(self.config["classes"])
        self.colors = get_n_hls_colors_v2(self.num_classes)  # one box colour per class

        # Warm the network up before real inference.
        self.warmup(warmup_epoch)

    def get_engine(self, engine_path):
        """
        Deserialize a TensorRT engine from a file.

        :param engine_path: path to the serialized engine
        :return: the deserialized engine
        :raises RuntimeError: if TensorRT fails to deserialize the file
        """
        print("Reading engine from file {}".format(engine_path))
        with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime:
            engine = runtime.deserialize_cuda_engine(f.read())
        # deserialize_cuda_engine signals failure by returning None; fail here
        # with a clear message instead of an AttributeError later on.
        if engine is None:
            raise RuntimeError("Failed to deserialize TensorRT engine: {}".format(engine_path))
        return engine

    def print_engine(self):
        """
        Print a table of the engine's bindings: name, direction, element
        count, shape and numpy dtype.

        NOTE(review): uses the binding-based engine API (get_binding_shape /
        binding_is_input / get_binding_dtype), which newer TensorRT releases
        deprecate in favour of the tensor-name API.
        """
        table = PrettyTable(["binding name", "is input", "binding size", "binding shape", "dtype"])
        for binding in self.engine:
            dims = self.engine.get_binding_shape(binding)
            size = trt.volume(dims)
            is_input = self.engine.binding_is_input(binding)
            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
            table.add_row([binding, is_input, size, str(dims), dtype])
        print(table)

    def warmup(self, epoch):
        """
        Warm the engine up by running `epoch` inferences on random input.

        :param epoch: number of warm-up inference runs
        """
        print("start warm up!")
        dummy = np.random.random(self.config["input_shape"])
        # np.copyto casts the float64 sample to the host buffer's dtype.
        np.copyto(self.inputs[0].host, dummy.reshape(-1))
        for _ in tqdm(range(epoch)):
            common.do_inference_v2(self.context, self.binding, self.inputs, self.outputs, self.stream)

    def transform(self, img):
        """
        Pre-process a BGR image into a normalized NCHW batch of one.

        :param img: input image in BGR channel order (H x W x 3)
        :return: float32 array of shape (1, 3, H', W'), where (H', W') is
                 ``config["input_shape"][2:]``
        """
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Letterbox (pad + resize) to the network input size, then scale to [0, 1].
        img = letterbox_image(img, self.config["input_shape"][2:]).astype(np.float32) / 255.

        # Normalize with the configured per-channel mean / std.
        img -= np.array(self.config["mean"])
        img /= np.array(self.config["std"])

        # HWC -> CHW, then add the batch axis, e.g. 640x640x3 -> 1x3x640x640.
        img = np.transpose(img, [2, 0, 1])
        img = np.expand_dims(img, axis=0)
        return img

    def predict(self, img):
        """
        Run detection on a single BGR image.

        :param img: input image in BGR channel order (H x W x 3)
        :return: detections as produced by ``get_real_boxes``, mapped back to
                 the original image size
        """
        h, w, _ = img.shape

        blob = self.transform(img)
        np.copyto(self.inputs[0].host, blob.reshape(-1))  # copy into the input host buffer
        result = common.do_inference_v2(self.context, self.binding, self.inputs, self.outputs, self.stream)

        # Restore each flat output to (1, 5 + num_classes, *stage_dims). Follow
        # the configured stages rather than assuming exactly three heads; the
        # shape is passed positionally (`newshape=` is deprecated in NumPy).
        head_channels = 5 + self.num_classes
        for i, stage in enumerate(self.config["stage"]):
            result[i] = np.reshape(result[i], [1, head_channels] + list(stage))

        result = decode_outputs(result, self.config["input_shape"][2:])  # decode raw head outputs
        result = non_max_suppression(result, self.num_classes, self.config["conf_thres"], self.config["iou_thres"])
        result = get_real_boxes(result, (w, h))  # map boxes back to the original image
        return result

    def draw_bboxes(self, img, bboxes, thickness=1, print_info=True):
        """
        Draw detection boxes and labels onto an image.

        Each box is read as ``box[0:4]`` = corner coordinates (x1, y1, x2, y2),
        ``box[4] * box[5]`` = displayed score, ``box[-1]`` = class index.

        :param img: image to draw on (cv2 drawing modifies it in place)
        :param bboxes: iterable of detection boxes
        :param thickness: line / text thickness
        :param print_info: whether to print each detection to stdout
        :return: the annotated image
        """
        for box in bboxes:
            cls_id = int(box[-1])
            color = self.colors[cls_id]
            x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
            img = cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness)
            conf = box[4] * box[5]  # combined score shown next to the label
            label = self.config["classes"][cls_id]
            img = cv2.putText(img, "{} {:.2f}".format(label, conf), (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, thickness)

            if print_info:
                print("{} {:.2f} {} {} {} {}".format(label, conf, x1, y1, x2, y2))
        return img

    def show_result(self, img, bboxes, thickness=1, print_info=True, show_type="PIL", title="result"):
        """
        Draw the detections and display the image.

        :param img: original BGR image
        :param bboxes: detection boxes (see ``draw_bboxes``)
        :param thickness: line / text thickness
        :param print_info: whether to print each detection to stdout
        :param show_type: display backend, either "PIL" or "cv2"
        :param title: window / image title
        :raises KeyError: if ``show_type`` is not "PIL" or "cv2"
        """
        # Validate before drawing so a bad argument neither mutates the
        # caller's image nor prints detections.
        if show_type not in ("cv2", "PIL"):
            raise KeyError("Please use cv2 or PIL.")

        img = self.draw_bboxes(img, bboxes, thickness, print_info)

        if show_type == "cv2":
            cv2.imshow(title, img)
            cv2.waitKey(0)
            cv2.destroyWindow(title)  # release the window once a key was pressed
        else:
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            Image.fromarray(rgb).show(title=title)
